import numpy

from chainer_chemistry.dataset.splitters.base_splitter import BaseSplitter


class TimeSplitter(BaseSplitter):
    """Class for doing time order splits.

    Samples are ordered by the time stamps given as ``time_list``: the
    earliest samples go to the training set, the following ones to the
    validation set and the latest ones to the test set.
    """

    def _split(self, dataset, frac_train=0.8, frac_valid=0.1, frac_test=0.1,
               **kwargs):
        """Compute train/valid/test indices ordered by time.

        Args:
            dataset: Dataset to split; only ``len(dataset)`` is used here.
            frac_train (float): Fraction of samples put into training data.
            frac_valid (float): Fraction of samples put into validation data.
            frac_test (float): Fraction of samples put into test data.
            **kwargs: Must contain ``time_list``, an iterable of mutually
                comparable time stamps whose ``i``-th entry belongs to
                ``dataset[i]``. Entries beyond ``len(dataset)`` are ignored.

        Returns:
            tuple of numpy.ndarray: Integer index arrays for the train,
            valid and test sets, each listed in increasing time order.

        Raises:
            ValueError: If ``time_list`` is missing or has fewer entries
                than ``dataset``.
        """
        numpy.testing.assert_almost_equal(
            frac_train + frac_valid + frac_test, 1.)

        time_list = kwargs.get('time_list')
        if time_list is None:
            raise ValueError('time_list is required for TimeSplitter.')

        n_samples = len(dataset)
        # Only the first ``n_samples`` time stamps describe ``dataset``.
        # Trimming them *before* sorting (rather than slicing the sorted
        # indices afterwards) guarantees every returned index is in range.
        times = list(time_list)[:n_samples]
        if len(times) < n_samples:
            raise ValueError(
                'time_list has {} entries but dataset has {} samples.'.format(
                    len(times), n_samples))

        train_cutoff = int(frac_train * n_samples)
        valid_cutoff = int((frac_train + frac_valid) * n_samples)

        # sorted() is stable, so samples with equal time stamps keep their
        # original dataset order.
        index = sorted(range(n_samples), key=times.__getitem__)

        # ``dtype=int`` keeps an empty split an integer array; a bare
        # ``numpy.array([])`` is float64, which numpy refuses as an index.
        return (numpy.array(index[:train_cutoff], dtype=int),
                numpy.array(index[train_cutoff:valid_cutoff], dtype=int),
                numpy.array(index[valid_cutoff:], dtype=int))

    def train_valid_test_split(self, dataset, time_list=None, frac_train=0.8,
                               frac_valid=0.1, frac_test=0.1, converter=None,
                               return_index=True, **kwargs):
        """Split dataset into train, valid and test set.

        Split indices are generated by splitting based on time order: the
        earliest samples form the training set and the latest samples form
        the test set.

        Args:
            dataset(NumpyTupleDataset, numpy.ndarray):
                Dataset.
            time_list(list):
                Time list corresponding to dataset; ``time_list[i]`` is the
                time stamp of ``dataset[i]``. Required despite the ``None``
                default.
            frac_train(float):
                Fraction of dataset put into training data.
            frac_valid(float):
                Fraction of dataset put into validation data.
            frac_test(float):
                Fraction of dataset put into test data.
            converter(callable):
                Forwarded unchanged to ``BaseSplitter``; presumably used to
                build the split datasets when ``return_index`` is ``False``.
            return_index(bool):
                If `True`, this function returns only indexes. If `False`, this
                function returns the split dataset.

        Returns:
            tuple: Split datasets, or integer index arrays if
            ``return_index`` is ``True``.

        Raises:
            ValueError: If ``time_list`` is missing or shorter than
                ``dataset``.

        .. admonition:: Example

            >>> import numpy
            >>> from chainer_chemistry.datasets import NumpyTupleDataset
            >>> from chainer_chemistry.dataset.splitters import TimeSplitter
            >>> a = numpy.random.random((10, 10))
            >>> b = numpy.random.random((10, 8))
            >>> c = numpy.random.random((10, 1))
            >>> d = NumpyTupleDataset(a, b, c)
            >>> time_list = numpy.random.permutation(10)
            >>> splitter = TimeSplitter()
            >>> train, valid, test = splitter.train_valid_test_split(
            ...     d, time_list=time_list, return_index=False)
            >>> print(len(train), len(valid), len(test))
            8 1 1
        """
        return super(TimeSplitter, self).train_valid_test_split(
            dataset, frac_train, frac_valid, frac_test, converter,
            return_index, time_list=time_list, **kwargs)

    def train_valid_split(self, dataset, time_list=None, frac_train=0.9,
                          frac_valid=0.1, converter=None, return_index=True,
                          **kwargs):
        """Split dataset into train and valid set.

        Split indices are generated by splitting based on time order: the
        earliest samples form the training set and the latest samples form
        the validation set.

        Args:
            dataset(NumpyTupleDataset, numpy.ndarray):
                Dataset.
            time_list(list):
                Time list corresponding to dataset; ``time_list[i]`` is the
                time stamp of ``dataset[i]``. Required despite the ``None``
                default.
            frac_train(float):
                Fraction of dataset put into training data.
            frac_valid(float):
                Fraction of dataset put into validation data.
            converter(callable):
                Forwarded unchanged to ``BaseSplitter``; presumably used to
                build the split datasets when ``return_index`` is ``False``.
            return_index(bool):
                If `True`, this function returns only indexes. If `False`, this
                function returns the split dataset.

        Returns:
            tuple: Split datasets, or integer index arrays if
            ``return_index`` is ``True``.

        Raises:
            ValueError: If ``time_list`` is missing or shorter than
                ``dataset``.

        .. admonition:: Example

            >>> import numpy
            >>> from chainer_chemistry.datasets import NumpyTupleDataset
            >>> from chainer_chemistry.dataset.splitters import TimeSplitter
            >>> a = numpy.random.random((10, 10))
            >>> b = numpy.random.random((10, 8))
            >>> c = numpy.random.random((10, 1))
            >>> d = NumpyTupleDataset(a, b, c)
            >>> time_list = numpy.random.permutation(10)
            >>> splitter = TimeSplitter()
            >>> train, valid = splitter.train_valid_split(
            ...     d, time_list=time_list, return_index=False)
            >>> print(len(train), len(valid))
            9 1
        """
        return super(TimeSplitter, self).train_valid_split(
            dataset, frac_train, frac_valid, converter, return_index,
            time_list=time_list, **kwargs)
